"""
 Module: Post processing of results - combining rankings from different methods,
 and preparing characteristics data for the outlier hospitals.
 """
__author__ = "shubhranshu-shekhar"
__date__ = "10/25/23"


import dill
import pandas as pd
import bz2, pickle
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import json
import rankaggregation as ra

try:
    from data_util import parse_args, mkdir_p, NumpyEncoder, read_data, read_chronic_conditions_data
except ImportError as e:
    from .data_util import parse_args, mkdir_p, NumpyEncoder, read_data, read_chronic_conditions_data


def _rank_by_score(df, score_col, id_col='PRVDR_NUM'):
    """Return ``df[id_col]`` as a list ordered by ``df[score_col]``, highest score first."""
    return df.sort_values(by=[score_col], ascending=False)[id_col].tolist()


def combine_rankings():
    """Aggregate provider outlier rankings from several methods into one ensemble ranking.

    Loads the subspace-detector scores (``sod``, ``ifr``, ``loda``, ``rshash``,
    ``rcf``), the peer-model excess amounts and the regression model's provider
    coefficients. Each source is turned into a list of providers ordered from
    highest to lowest score. The lists are combined with instant-runoff rank
    aggregation, and the result is pickled (bz2) to
    ``../results/ensemble_ranking.bz2.pkl``.

    Raises:
        ValueError: if the regression model has fewer coefficients than there
            are providers in the column map (or the map is empty).
    """
    # ---- Subspace detectors -------------------------------------------------
    with bz2.BZ2File('../output/subspace_module.bz.dill', 'rb') as f:
        # Only the per-detector score arrays are used here.
        [_methods, _models, scores] = dill.load(f)

    # The Jaccard matrix index gives the provider order the subspace scores follow.
    with bz2.BZ2File('../data/prvdr_icd10_jaccard_df.bz2.pkl', 'rb') as f:
        df_jaccard = pickle.load(f)
    prvdr_lst_subspace = list(df_jaccard.index)

    subspace_cols = ['sod', 'ifr', 'loda', 'rshash', 'rcf']
    provider_subspace_df = pd.DataFrame({
        'PRVDR_NUM': prvdr_lst_subspace,
        **{col: scores[i] for i, col in enumerate(subspace_cols)},
    })
    subspace_ranks = {col: _rank_by_score(provider_subspace_df, col)
                      for col in subspace_cols}

    # ---- Peer model -----------------------------------------------------------
    with bz2.BZ2File('../output/peer_excess_amount.bz2.pkl', 'rb') as f:
        provider_excess_dct = pickle.load(f)

    # Rank providers by their excess amount (the dict value), largest first.
    # NOTE(review): assumes the values are comparable scalars -- confirm.
    peer_ranking = sorted(provider_excess_dct, key=provider_excess_dct.get, reverse=True)

    # ---- Regression model -----------------------------------------------------
    with bz2.BZ2File('../output/reg.pkl', 'rb') as f:
        reg = pickle.load(f)

    # Provider -> column-index map of the regression design matrix.
    with bz2.BZ2File('../data/sum_mat_target.bz2.pkl', 'rb') as f:
        [_, _, bene_prvdr_col_map] = pickle.load(f)

    col_to_prvdr = {col: prvdr for prvdr, col in bene_prvdr_col_map.items()}
    prvdr_col_list = [col_to_prvdr[i] for i in range(len(col_to_prvdr))]

    print(len(bene_prvdr_col_map.items()))
    # The provider coefficients are the trailing block of ``reg.coef_``, one per
    # provider column. The size is taken from the column map instead of a
    # hard-coded count.
    # NOTE(review): assumes reg.coef_ is 1-D and that this trailing block follows
    # the column indices of bene_prvdr_col_map -- confirm.
    coefs = reg.coef_
    n_providers = len(prvdr_col_list)
    if n_providers == 0 or n_providers > len(coefs):
        raise ValueError(
            f"Cannot take {n_providers} provider coefficients from a model "
            f"with {len(coefs)} coefficients")
    provider_coefs = coefs[len(coefs) - n_providers:]

    provider_coef_df = pd.DataFrame({
        'PRVDR_NUM': prvdr_col_list,
        'coef': provider_coefs,
    })
    regression_ranks = _rank_by_score(provider_coef_df, 'coef')

    # ---- Ensemble ---------------------------------------------------------------
    # NOTE(review): the rcf ranking is computed but not part of the ensemble
    # (unchanged from the original selection) -- confirm this is intended.
    agg = ra.RankAggregator()
    combined_ranking = agg.instant_runoff([
        regression_ranks,
        subspace_ranks['sod'],
        subspace_ranks['ifr'],
        subspace_ranks['loda'],
        subspace_ranks['rshash'],
        peer_ranking,
    ])

    # save the combined ranking
    with bz2.BZ2File('../results/ensemble_ranking.bz2.pkl', 'wb') as f:
        pickle.dump(combined_ranking, f)


def _save_and_close(path):
    """Save the current matplotlib figure to ``path`` and close it.

    pyplot keeps every figure alive until it is closed, and this module
    creates several figures in one run.
    """
    plt.savefig(path, bbox_inches='tight')
    plt.close()


def prepare_charateristics_data(DOJ=True, top_k=110):
    """Plot characteristics of outlier hospitals against all reference hospitals.

    Steps:
    1. Build a per-hospital table from the 2017 inpatient claims (mean length of
       stay, unique beneficiaries), the hospital general-information CSV (state,
       ownership, overall rating) and the provider-service CSV (location type).
    2. Select the outlier hospitals.
    3. Restrict both the outliers and all hospitals to the reference hospitals in
       ``../output/hospitals_type1_subtype1_inpatient.csv``.
    4. Save one distribution plot per characteristic as a PDF under ``../output/``.

    Args:
        DOJ: If True, the outliers are the hospitals listed in
            ``../data/outlier_ccn.csv`` and ``../data/oa_outlier_ccn.csv``.
            If False, they are the hospitals ranked ``<= top_k`` in the ensemble
            ranking written by ``combine_rankings``.
        top_k: Rank cutoff used when ``DOJ`` is False.
    """
    # ---- Inpatient claims: per-provider mean length of stay and unique patients
    args = parse_args()
    args.savepath = '../output/'
    PCT = '100pct'  # '0001pct'
    args.year = "2017"
    args.pct = PCT

    print("Reading inpatient data...")
    claims_df = read_data(args, "2017")
    claims_df['los'] = (claims_df['thru_dt'] - claims_df['from_dt']).dt.days
    print("Shape of inpatient data: ", claims_df.shape)

    by_provider = claims_df.groupby('provider')
    prvdr_los_dct = by_provider['los'].mean().to_dict()
    prvdr_bene_count_dct = by_provider['bene_id'].nunique().to_dict()
    # Only the per-provider summaries are needed from here on.
    del by_provider, claims_df

    # ---- Hospital general information
    # Read IDs and ratings as text. Facility IDs are matched against string CCNs
    # below, and ratings against the string categories '1'..'5'. If pandas
    # inferred numbers instead, both comparisons would break, and IDs would lose
    # any leading zeros.
    gen_info_df = pd.read_csv('../data/Hospital_General_Information.csv', header=0,
                              dtype={'Facility ID': str, 'Hospital overall rating': str})

    # NOTE(review): 'Voluntary non-profit - Private' maps to 'P' while the other
    # non-profit categories map to 'NP' -- confirm this is intended.
    ownership_dct = {'Government - Hospital District or Authority': 'G', 'Proprietary': 'P',
                    'Voluntary non-profit - Private':'P', 'Government - State': 'G',
                    'Voluntary non-profit - Other': 'NP', 'Government - Local': 'G',
                    'Voluntary non-profit - Church': 'NP', 'Government - Federal': 'G', 'Tribal': 'T',
                    'Department of Defense': 'G', 'Physician': 'P'}

    # Explicit copy: new columns are added to this frame below.
    gen_info_df = gen_info_df[['Facility ID', 'Facility Name', 'State', 'ZIP Code', 'Hospital Type',
                               'Hospital Ownership', 'Hospital overall rating']].copy()
    gen_info_df['owner-type'] = gen_info_df['Hospital Ownership'].map(ownership_dct)
    gen_info_df['alos'] = gen_info_df['Facility ID'].map(prvdr_los_dct)
    gen_info_df['uniq_patients'] = gen_info_df['Facility ID'].map(prvdr_bene_count_dct)

    # ---- Urban / rural location from the provider-service file
    # CCNs are read as text for the same reason as Facility ID above.
    service_df = pd.read_csv('../data/MUP_IHP_RY21_P02_V10_DY19_PrvSvc_0.csv',
                             dtype={'Rndrng_Prvdr_CCN': str})
    loc_map_dct = {'Metropolitan area core: primary flow within an urbanized area of 50,000 and greater': 'urban',
            'Micropolitan area core: primary flow within an urban cluster of 10,000 to 49,999': 'urban',
            'Small town core: primary flow within an urban cluster of 2,500 to 9,999': 'town',
            'Metropolitan area low commuting: primary flow 10% to <30% to a urbanized area of 50,000 and greater': 'urban',
            'Small town low commuting: primary flow 10% to <30% to a urban cluster of 2,500 to 9,999': 'town',
            'Secondary flow 30% to <50% to a urbanized area of 50,000 and greater': 'urban',
            'Micropolitan high commuting: primary flow 30% or more to a urban cluster of 10,000 to 49,999': 'urban',
            'Rural areas: primary flow to a tract outside a urbanized area of 50,000 and greater or UC': 'rural',
            'Metropolitan area high commuting: primary flow 30% or more to a urbanized area of 50,000 and greater': 'urban',
            'Secondary flow 30% to <50% to a larger urbanized area of 50,000 and greater': 'urban',
            'Unknown': 'unknown',
            'Micropolitan low commuting: primary flow 10% to <30% to a urban cluster of 10,000 to 49,999': 'urban',
            'Small town high commuting: primary flow 30% or more to a urban cluster of 2,500 to 9,999': 'town',
            'Secondary flow 30% to <50% to a urban cluster of 10,000 to 49,999': 'urban'}

    service_df = service_df[['Rndrng_Prvdr_CCN',
                             'Rndrng_Prvdr_RUCA_Desc']].drop_duplicates().reset_index(drop=True)
    service_df['loc_type'] = service_df['Rndrng_Prvdr_RUCA_Desc'].map(loc_map_dct)
    prvdr_loc_dct = dict(zip(service_df['Rndrng_Prvdr_CCN'], service_df['loc_type']))
    gen_info_df['loc_type'] = gen_info_df['Facility ID'].map(prvdr_loc_dct)

    # ---- Keep hospitals with complete information; use CamelCase column names
    reg_df = gen_info_df.dropna().reset_index(drop=True)
    reg_df = reg_df.rename(columns={
        'Facility ID': 'FacilityID',
        'Facility Name': 'FacilityName',
        'ZIP Code': 'ZIPCode',
        'Hospital Type': 'HospitalType',
        'Hospital Ownership': 'HospitalOwnership',
        'Hospital overall rating': 'HospitalOverallRating',
        'owner-type': 'OwnerType',
        'alos': 'Alos',
        'uniq_patients': 'UniqPatients',
        'loc_type': 'LocType',
    })
    print("Columns of reg_df: ", reg_df.columns)

    # ---- Select outlier hospitals
    if DOJ:
        # Ground truth: outlier list based on news articles.
        outlier_label = 'DOJ'
        ccn_df = pd.concat([
            pd.read_csv('../data/outlier_ccn.csv', dtype={'CCN': str}),
            pd.read_csv('../data/oa_outlier_ccn.csv', dtype={'CCN': str}),
        ]).reset_index(drop=True)
        outlier_ccns = ccn_df['CCN'].dropna().unique().tolist()
        outlier_df = reg_df[reg_df['FacilityID'].isin(outlier_ccns)].reset_index(drop=True)
    else:
        # Hospitals flagged by the ensemble ranking.
        outlier_label = 'Flagged'
        print(f"Working with hospitals ranked <= {top_k} in the EM ranking...")
        with bz2.BZ2File("../results/ensemble_ranking.bz2.pkl", 'rb') as f:
            em_ranking = pickle.load(f)
        # 1-based rank of each hospital's first appearance in the ranking.
        rank_dct = {}
        for rank, hospital in enumerate(em_ranking, start=1):
            rank_dct.setdefault(hospital, rank)
        reg_df['rank'] = reg_df['FacilityID'].map(rank_dct)
        outlier_df = reg_df[reg_df['rank'] <= top_k].reset_index(drop=True)

    # ---- Restrict both groups to the reference hospitals
    ref_prvdr_data = pd.read_csv('../output/hospitals_type1_subtype1_inpatient.csv',
                                 dtype={'PRVDR_NUM': str})
    ref_ccns = ref_prvdr_data['PRVDR_NUM']

    print("Shape of outlier_df before: ", outlier_df.shape)
    outlier_df = outlier_df[outlier_df['FacilityID'].isin(ref_ccns)].reset_index(drop=True)
    print("Shape of outlier_df after: ", outlier_df.shape)

    print("Shape of reg_df before: ", reg_df.shape)
    reg_df = reg_df[reg_df['FacilityID'].isin(ref_ccns)].reset_index(drop=True)
    print("Shape of reg_df after: ", reg_df.shape)

    # ---- Overall rating
    # Values outside '1'..'5' (e.g. 'Not Available') become NaN categories and are not plotted.
    rating_order = ['1', '2', '3', '4', '5']
    outlier_df['HospitalOverallRating'] = pd.Categorical(outlier_df['HospitalOverallRating'], rating_order)
    rated_df = reg_df[reg_df['HospitalOverallRating'] != 'Not Available'].copy()
    rated_df['HospitalOverallRating'] = pd.Categorical(rated_df['HospitalOverallRating'], rating_order)

    plt.figure(figsize=(8, 3))
    sns.histplot(x='HospitalOverallRating', data=outlier_df, stat="density", label=outlier_label,
                 color="orange", alpha=1)
    sns.histplot(x='HospitalOverallRating', data=rated_df, stat="density", label='All',
                 color="blue", alpha=0.2)
    plt.legend()
    _save_and_close('../output/hospital_rating_distribution.pdf')

    # ---- State
    # Outlier states first (most frequent first), then the remaining states of reg_df.
    state_order = list(outlier_df['State'].value_counts().index)
    outlier_df['State'] = pd.Categorical(outlier_df['State'], state_order)
    all_state_order = state_order + [s for s in reg_df['State'].value_counts().index
                                     if s not in state_order]
    reg_df['State'] = pd.Categorical(reg_df['State'], all_state_order)

    plt.figure(figsize=(20, 3))
    sns.histplot(x='State', data=outlier_df, stat="density", label=outlier_label, color="orange",
                 alpha=1)
    sns.histplot(x='State', data=reg_df, stat="density", label='All', color="blue", alpha=0.2)
    plt.xlabel('State', fontsize=14)
    plt.ylabel('Density', fontsize=14)
    plt.legend(fontsize=14)
    _save_and_close('../output/state_distribution.pdf')

    # ---- Ownership type (the 'All' group excludes tribal hospitals)
    plt.figure(figsize=(5, 3))
    sns.histplot(x='OwnerType', data=outlier_df, stat="density", label=outlier_label, color="orange",
                 alpha=1)
    sns.histplot(x='OwnerType', data=reg_df[reg_df['OwnerType'] != 'T'], stat="density", label='All',
                 color="blue", alpha=0.2)
    plt.legend()
    _save_and_close('../output/ownership_distribution.pdf')

    # ---- Location type, collapsed to Urban vs Non Urban
    loc_dct = {'urban': 'Urban', 'rural': 'Non Urban', 'town': 'Non Urban', 'unknown': 'Non Urban'}
    reg_df['Location'] = reg_df['LocType'].map(loc_dct)
    outlier_df['Location'] = outlier_df['LocType'].map(loc_dct)

    plt.figure(figsize=(3.5, 3))
    sns.histplot(x='Location', data=outlier_df, stat="density", label=outlier_label, color="orange",
                 alpha=1)
    sns.histplot(x='Location', data=reg_df, stat="density", label='All', color="blue", alpha=0.2)
    plt.legend()
    _save_and_close('../output/location_distribution.pdf')

    # NOTE(review): sns.distplot is deprecated in recent seaborn releases;
    # consider histplot(..., kde=True, stat="density") when upgrading.

    # ---- Average length of stay
    plt.figure(figsize=(8, 3))
    ax = sns.distplot(outlier_df['Alos'], label=outlier_label, color="orange")
    sns.distplot(reg_df['Alos'], label='All', color="blue", ax=ax)
    plt.legend()
    _save_and_close('../output/alos_distribution.pdf')

    # ---- Unique patients
    plt.figure(figsize=(8, 3))
    ax = sns.distplot(outlier_df['UniqPatients'], label=outlier_label, color="orange")
    sns.distplot(reg_df['UniqPatients'], label='All', color="blue", ax=ax)
    plt.legend()
    _save_and_close('../output/uniq_patients_distribution.pdf')

def count_DOJ_short_stay_hospitals(
        outlier_paths=('../data/outlier_ccn.csv', '../data/oa_outlier_ccn.csv'),
        ref_path='../output/hospitals_type1_subtype1_inpatient.csv',
        out_path='../data/groundtruth_outlier_ccn.csv'):
    """Write the ground-truth outlier hospitals that are also reference hospitals.

    Steps:
    1. Concatenate the outlier CCN lists, which are based on news articles.
    2. Drop duplicate CCNs, keeping the first occurrence.
    3. Keep only CCNs listed in the reference hospital file.
    4. Write the ``CCN`` and ``name`` columns to ``out_path``, sorted by CCN.

    Args:
        outlier_paths: CSV files with ``CCN`` and ``name`` columns, concatenated
            in the given order.
        ref_path: CSV with a ``PRVDR_NUM`` column listing the reference hospitals.
        out_path: Destination CSV.
    """
    # Read CCNs as text. Otherwise pandas parses them as numbers, so
    # '010001' becomes '10001' after astype(str), and missing values become 'nan'.
    # Either way the IDs would no longer match the reference file.
    outliers = pd.concat(
        [pd.read_csv(path, dtype={'CCN': str}) for path in outlier_paths]
    ).reset_index(drop=True)
    outliers = outliers.dropna(subset=['CCN'])

    # drop duplicates
    outliers = outliers.drop_duplicates(subset=['CCN']).reset_index(drop=True)

    ref_prvdr_data = pd.read_csv(ref_path, dtype={'PRVDR_NUM': str})

    # get the outlier CCNs that are in the reference data
    print("Shape of CCN df before: ", outliers.shape)
    outliers = outliers[outliers['CCN'].isin(ref_prvdr_data['PRVDR_NUM'])].reset_index(drop=True)
    print("Shape of CCN after: ", outliers.shape)

    # save to csv sorted by CCN
    outliers = outliers.sort_values(by='CCN').reset_index(drop=True)
    outliers[['CCN', 'name']].to_csv(out_path, index=False)


def _main():
    """Run the full post-processing pipeline.

    Order matters: ``combine_rankings`` writes the ensemble ranking that
    ``prepare_charateristics_data(DOJ=False)`` reads back.
    """
    combine_rankings()
    prepare_charateristics_data(DOJ=False)
    count_DOJ_short_stay_hospitals()


if __name__ == '__main__':
    _main()
